Libraries

library(fastbaps)
library(rhierbaps)
library(ggtree)
library(phytools)
library(ggtree)
library(ggplot2)
library(data.table)
library(patchwork)
library(stringr)
library(dplyr)
library(Matrix)
library(ggbeeswarm)

set.seed(1234)

Some plotting functions

# Draw the tree with one tile panel per clustering result (full comparison).
#
# tree:        tree object passed to ggtree().
# multi.baps, multi.hc, multi.flat:
#              data frames keyed by an `Isolates` column; every other column
#              (`Level 1`, `Level 2`) is prefixed with "baps_", "hc_" or
#              "flat_" before the three tables are merged.
# hb:          data frame with `Isolate`, `level 1` and `level 2` columns.
# bbp, snap:   vectors whose names are isolate identifiers.
#
# Returns the ggtree plot with nine facet panels and a text size of 15.
plot_custom_ggtree_snap <- function(tree, multi.baps, multi.hc, multi.flat, 
    hb, bbp, snap) {
    # Prefix the non-key columns so they stay distinguishable after merging.
    baps.cols <- 2:ncol(multi.baps)
    hc.cols <- 2:ncol(multi.hc)
    flat.cols <- 2:ncol(multi.flat)
    colnames(multi.baps)[baps.cols] <- paste0("baps_", colnames(multi.baps)[baps.cols])
    colnames(multi.hc)[hc.cols] <- paste0("hc_", colnames(multi.hc)[hc.cols])
    colnames(multi.flat)[flat.cols] <- paste0("flat_", colnames(multi.flat)[flat.cols])

    clusters <- merge(multi.baps, multi.hc, by = "Isolates")
    clusters <- merge(clusters, multi.flat, by = "Isolates")

    # Align the externally computed partitions to the merged isolate order.
    hb.rows <- match(clusters$Isolates, hb$Isolate)
    clusters$bbp.tree <- bbp[match(clusters$Isolates, names(bbp))]
    clusters$HB1 <- hb$`level 1`[hb.rows]
    clusters$HB2 <- hb$`level 2`[hb.rows]
    clusters$snap <- snap[match(clusters$Isolates, names(snap))]

    # One entry per panel, in left-to-right display order.
    panel.specs <- list(
        list(title = "hierBAPS L1", mapping = aes(x = HB1), fill = "#e41a1c"),
        list(title = "hierBAPS L2", mapping = aes(x = HB2), fill = "#377eb8"),
        list(title = "Level 1 - BAPS prior", mapping = aes(x = `baps_Level 1`), fill = "#4daf4a"),
        list(title = "Level 2 - BAPS prior", mapping = aes(x = `baps_Level 2`), fill = "#984ea3"),
        list(title = "Level 1 - HC prior", mapping = aes(x = `hc_Level 1`), fill = "#ff7f00"),
        list(title = "snapclust AIC", mapping = aes(x = snap), fill = "#f781bf"),
        list(title = "Level 1 - opt BAPS prior", mapping = aes(x = `flat_Level 1`), fill = "#b2df8a"),
        list(title = "Level 2 - opt BAPS prior", mapping = aes(x = `flat_Level 2`), fill = "#cab2d6"),
        list(title = "bbp.tree", mapping = aes(x = bbp.tree), fill = "#a65628")
    )

    tree.plot <- ggtree(tree)
    for (spec in panel.specs) {
        tree.plot <- facet_plot(tree.plot, panel = spec$title, data = clusters, 
            geom = geom_tile, spec$mapping, fill = spec$fill)
    }
    tree.plot + theme(text = element_text(size = 15))
}


# Publication variant of the tree/cluster comparison: five panels only.
#
# tree:        tree object passed to ggtree().
# multi.baps, multi.hc, multi.flat:
#              data frames keyed by an `Isolates` column; every other column
#              is prefixed with "baps_", "hc_" or "flat_" before merging.
# hb:          data frame with `Isolate`, `level 1` and `level 2` columns.
# bbp, snap:   vectors whose names are isolate identifiers.
#
# Returns the ggtree plot with five facet panels and a text size of 17.
plot_custom_ggtree_snap_2 <- function(tree, multi.baps, multi.hc, multi.flat, 
    hb, bbp, snap) {
    # Prefix the non-key columns so they stay distinguishable after merging.
    baps.cols <- 2:ncol(multi.baps)
    hc.cols <- 2:ncol(multi.hc)
    flat.cols <- 2:ncol(multi.flat)
    colnames(multi.baps)[baps.cols] <- paste0("baps_", colnames(multi.baps)[baps.cols])
    colnames(multi.hc)[hc.cols] <- paste0("hc_", colnames(multi.hc)[hc.cols])
    colnames(multi.flat)[flat.cols] <- paste0("flat_", colnames(multi.flat)[flat.cols])

    clusters <- merge(multi.baps, multi.hc, by = "Isolates")
    clusters <- merge(clusters, multi.flat, by = "Isolates")

    # Align the externally computed partitions to the merged isolate order.
    hb.rows <- match(clusters$Isolates, hb$Isolate)
    clusters$bbp.tree <- bbp[match(clusters$Isolates, names(bbp))]
    clusters$HB1 <- hb$`level 1`[hb.rows]
    clusters$HB2 <- hb$`level 2`[hb.rows]
    clusters$snap <- snap[match(clusters$Isolates, names(snap))]

    # One entry per panel, in left-to-right display order.
    panel.specs <- list(
        list(title = "hierBAPS", mapping = aes(x = HB1), fill = "#e41a1c"),
        list(title = "snapclust AIC", mapping = aes(x = snap), fill = "#f781bf"),
        list(title = "BHC prior", mapping = aes(x = `hc_Level 1`), fill = "#ff7f00"),
        list(title = "optimised BAPs prior", mapping = aes(x = `flat_Level 1`), fill = "#b2df8a"),
        list(title = "fastbaps partition", mapping = aes(x = bbp.tree), fill = "#a65628")
    )

    tree.plot <- ggtree(tree)
    for (spec in panel.specs) {
        tree.plot <- facet_plot(tree.plot, panel = spec$title, data = clusters, 
            geom = geom_tile, spec$mapping, fill = spec$fill)
    }
    tree.plot + theme(text = element_text(size = 17))
}

# Five-panel tree/cluster comparison whose "flat" panel is labelled as the
# optimised symmetric prior.
#
# tree:        tree object passed to ggtree().
# multi.baps, multi.hc, multi.flat:
#              data frames keyed by an `Isolates` column; every other column
#              is prefixed with "baps_", "hc_" or "flat_" before merging.
# hb:          data frame with `Isolate`, `level 1` and `level 2` columns.
# bbp, snap:   vectors whose names are isolate identifiers.
#
# Returns the ggtree plot with five facet panels and a text size of 15.
plot_custom_ggtree_snap_3 <- function(tree, multi.baps, multi.hc, multi.flat, 
    hb, bbp, snap) {
    # Prefix the non-key columns so they stay distinguishable after merging.
    baps.cols <- 2:ncol(multi.baps)
    hc.cols <- 2:ncol(multi.hc)
    flat.cols <- 2:ncol(multi.flat)
    colnames(multi.baps)[baps.cols] <- paste0("baps_", colnames(multi.baps)[baps.cols])
    colnames(multi.hc)[hc.cols] <- paste0("hc_", colnames(multi.hc)[hc.cols])
    colnames(multi.flat)[flat.cols] <- paste0("flat_", colnames(multi.flat)[flat.cols])

    clusters <- merge(multi.baps, multi.hc, by = "Isolates")
    clusters <- merge(clusters, multi.flat, by = "Isolates")

    # Align the externally computed partitions to the merged isolate order.
    hb.rows <- match(clusters$Isolates, hb$Isolate)
    clusters$bbp.tree <- bbp[match(clusters$Isolates, names(bbp))]
    clusters$HB1 <- hb$`level 1`[hb.rows]
    clusters$HB2 <- hb$`level 2`[hb.rows]
    clusters$snap <- snap[match(clusters$Isolates, names(snap))]

    # One entry per panel, in left-to-right display order.
    panel.specs <- list(
        list(title = "hierBAPS", mapping = aes(x = HB1), fill = "#e41a1c"),
        list(title = "snapclust AIC", mapping = aes(x = snap), fill = "#f781bf"),
        list(title = "BHC prior", mapping = aes(x = `hc_Level 1`), fill = "#ff7f00"),
        list(title = "optimised symmetric prior", mapping = aes(x = `flat_Level 1`), fill = "#b2df8a"),
        list(title = "best BAPs partition fasttree", mapping = aes(x = bbp.tree), fill = "#a65628")
    )

    tree.plot <- ggtree(tree)
    for (spec in panel.specs) {
        tree.plot <- facet_plot(tree.plot, panel = spec$title, data = clusters, 
            geom = geom_tile, spec$mapping, fill = spec$fill)
    }
    tree.plot + theme(text = element_text(size = 15))
}

Escherichia Coli

# E. coli: cluster the core-gene SNP alignment under three priors, load the
# hierBAPS and snapclust results from disk, and plot everything on the tree.

# Import with prior = "baps" and cluster at two levels.
e.coli.data <- import_fasta_sparse_nt("./data/escherichia_coli/escherichia_coli_core_gene_alignment.snps.aln", 
    prior = "baps")

dim(e.coli.data$snp.matrix)
## [1] 178117   1508
e.coli.multi.baps <- multi_res_baps(e.coli.data, levels = 2, n.cores = 4)

# Re-optimise the prior (type = "hc") on the same object and re-cluster.
e.coli.data <- optimise_prior(e.coli.data, type = "hc", n.cores = 4)
## [1] "Optimised hyperparameter: 0.134"
e.coli.multi.hc <- multi_res_baps(e.coli.data, levels = 2, n.cores = 4)

# Re-optimise again (type = "optimise.baps"); e.coli.data keeps this prior
# for the best_baps_partition() call further down.
e.coli.data <- optimise_prior(e.coli.data, type = "optimise.baps", n.cores = 4)
## [1] "Optimised hyperparameter: 0.063"
e.coli.multi.flat <- multi_res_baps(e.coli.data, levels = 2, n.cores = 4)

# hierBAPS partition, reordered to the isolate order of the fastbaps output.
e.coli.hb <- fread("./data/escherichia_coli/hierbaps_partition.csv", data.table = FALSE)
e.coli.hb <- e.coli.hb[match(e.coli.multi.baps$Isolates, e.coli.hb$Isolate), 
    ]

# snapclust: take the K of the run with the lowest AIC.
# NOTE(review): K is used as a column *position* in the results table --
# confirm the file layout puts the K-cluster solution in column K.
e.coli.snap <- fread("./data/escherichia_coli_core_gene_alignment.snps_snapclust_results.csv", 
    data.table = FALSE)
e.coli.sum <- fread("./data/escherichia_coli_core_gene_alignment.snps_snapclust_run_summary.csv", 
    data.table = FALSE)
e.coli.snap.aic <- e.coli.snap[, e.coli.sum$K[which.min(e.coli.sum$AIC)]]
names(e.coli.snap.aic) <- e.coli.snap$Isolate

e.coli.tree <- phytools::read.newick("./data/escherichia_coli/escherichia_coli_core_gene_alignment.snps.aln.fasttree")

# Midpoint-root the tree before computing the tree-constrained partition.
e.coli.tree <- phytools::midpoint.root(e.coli.tree)
e.coli.bbp.tree <- fastbaps::best_baps_partition(e.coli.data, e.coli.tree)
## [1] "Calculating node marginal llks..."
## [1] "Finding best partition..."
plot_custom_ggtree_snap(tree = e.coli.tree, multi.baps = e.coli.multi.baps, 
    multi.hc = e.coli.multi.hc, multi.flat = e.coli.multi.flat, hb = e.coli.hb, 
    bbp = e.coli.bbp.tree, snap = e.coli.snap.aic)

Haemophilus Influenzae

# H. influenzae: same pipeline as the other species -- cluster under three
# priors, load hierBAPS and snapclust results, and plot on the tree.

# Import with prior = "baps" and cluster at two levels.
h.inf.data <- import_fasta_sparse_nt("./data/haemophilus_influenzae/haemophilus_influenzae_core_gene_snps.aln", 
    prior = "baps")

dim(h.inf.data$snp.matrix)
## [1] 103850     75
h.inf.multi.baps <- multi_res_baps(h.inf.data, levels = 2, n.cores = 4)

# Re-optimise the prior (type = "hc") and re-cluster.
h.inf.data <- optimise_prior(h.inf.data, type = "hc", n.cores = 4)
## [1] "Optimised hyperparameter: 0.08"
h.inf.multi.hc <- multi_res_baps(h.inf.data, levels = 2, n.cores = 4)

# Re-optimise again (type = "optimise.baps"); this prior is the one in
# effect for best_baps_partition() below.
h.inf.data <- optimise_prior(h.inf.data, type = "optimise.baps", n.cores = 4)
## [1] "Optimised hyperparameter: 0.197"
h.inf.multi.flat <- multi_res_baps(h.inf.data, levels = 2, n.cores = 4)

# hierBAPS partition, reordered to the isolate order of the fastbaps output.
h.inf.hb <- fread("./data/haemophilus_influenzae_core_gene_snps_hierbaps_partition.csv", 
    data.table = FALSE)
h.inf.hb <- h.inf.hb[match(h.inf.multi.baps$Isolates, h.inf.hb$Isolate), ]

# snapclust: take the K of the run with the lowest AIC.
# NOTE(review): K is used as a column *position* in the results table --
# confirm the file layout puts the K-cluster solution in column K.
h.inf.snap <- fread("./data/haemophilus_influenzae/haemophilus_influenzae_core_gene_snps_snapclust_results.csv", 
    data.table = FALSE)
h.inf.sum <- fread("./data/haemophilus_influenzae/haemophilus_influenzae_core_gene_snps_snapclust_run_summary.csv", 
    data.table = FALSE)
h.inf.snap.aic <- h.inf.snap[, h.inf.sum$K[which.min(h.inf.sum$AIC)]]
names(h.inf.snap.aic) <- h.inf.snap$Isolate

h.inf.tree <- phytools::read.newick("./data/haemophilus_influenzae/haemophilus_influenzae_core_gene_snps.aln.fasttree")

# Midpoint-root the tree before computing the tree-constrained partition.
h.inf.tree <- phytools::midpoint.root(h.inf.tree)
h.inf.bbp.tree <- fastbaps::best_baps_partition(h.inf.data, h.inf.tree)
## [1] "Calculating node marginal llks..."
## [1] "Finding best partition..."
plot_custom_ggtree_snap(tree = h.inf.tree, multi.baps = h.inf.multi.baps, multi.hc = h.inf.multi.hc, 
    multi.flat = h.inf.multi.flat, hb = h.inf.hb, bbp = h.inf.bbp.tree, snap = h.inf.snap.aic)

Plots for publication

# Reduced five-panel version of the H. influenzae comparison, built from the
# objects computed in the H. influenzae section.
plot_custom_ggtree_snap_2(tree = h.inf.tree, multi.baps = h.inf.multi.baps, 
    multi.hc = h.inf.multi.hc, multi.flat = h.inf.multi.flat, hb = h.inf.hb, 
    bbp = h.inf.bbp.tree, snap = h.inf.snap.aic)

Listeria Monocytogenes

# L. monocytogenes: cluster under three priors, load hierBAPS and snapclust
# results, and plot on the tree.

# Import with prior = "baps" and cluster at two levels.
listeria.data <- fastbaps::import_fasta_sparse_nt("./data/listeria_monocytogenes/listeria_monocytogenes_core_gene_snps.aln", 
    prior = "baps")
dim(listeria.data$snp.matrix)
## [1] 143582    128
listeria.multi.baps <- multi_res_baps(listeria.data, levels = 2, n.cores = 4)

# Re-optimise the prior (type = "hc") and re-cluster.
listeria.data <- optimise_prior(listeria.data, type = "hc", n.cores = 4)
## [1] "Optimised hyperparameter: 0.081"
listeria.multi.hc <- multi_res_baps(listeria.data, levels = 2, n.cores = 4)

# Re-optimise again (type = "optimise.baps"); this prior is the one in
# effect for best_baps_partition() below.
listeria.data <- optimise_prior(listeria.data, type = "optimise.baps", n.cores = 4)
## [1] "Optimised hyperparameter: 0.063"
listeria.multi.flat <- multi_res_baps(listeria.data, levels = 2, n.cores = 4)

# hierBAPS partition. NOTE(review): rows are ordered by listeria.multi.hc
# here, whereas the other species sections order by the *.multi.baps table.
listeria.hb <- fread("./data/listeria_monocytogenes_core_gene_snps_hierbaps_partition.csv", 
    data.table = FALSE)
listeria.hb <- listeria.hb[match(listeria.multi.hc$Isolates, listeria.hb$Isolate), 
    ]

# snapclust: take the K of the run with the lowest AIC.
# NOTE(review): K is used as a column *position* in the results table --
# confirm the file layout puts the K-cluster solution in column K.
listeria.snap <- fread("./data/listeria_monocytogenes_core_gene_snps_snapclust_results.csv", 
    data.table = FALSE)
listeria.sum <- fread("./data/listeria_monocytogenes_core_gene_snps_snapclust_run_summary.csv", 
    data.table = FALSE)
listeria.snap.aic <- listeria.snap[, listeria.sum$K[which.min(listeria.sum$AIC)]]
names(listeria.snap.aic) <- listeria.snap$Isolate

# This section reads the tree with ape::read.tree rather than
# phytools::read.newick as used for the other species.
listeria.tree <- ape::read.tree("./data/listeria_monocytogenes/listeria_monocytogenes_core_gene_snps.aln.fasttree")

# Midpoint-root the tree before computing the tree-constrained partition.
listeria.tree <- phytools::midpoint.root(listeria.tree)
listeria.bbp.tree <- fastbaps::best_baps_partition(listeria.data, listeria.tree)
## [1] "Calculating node marginal llks..."
## [1] "Finding best partition..."
# Quick look: tree with internal node labels next to the level-1 clusters
# from the optimised-prior run only.
gg <- ggtree(listeria.tree) + geom_text2(aes(subset = !isTip, label = label), 
    hjust = -0.1)
f1 <- facet_plot(gg, panel = "listeria.multi.flat", data = listeria.multi.flat, 
    geom = geom_tile, aes(x = `Level 1`), fill = "#e41a1c")
f1

plot_custom_ggtree_snap(tree = listeria.tree, multi.baps = listeria.multi.baps, 
    multi.hc = listeria.multi.hc, multi.flat = listeria.multi.flat, hb = listeria.hb, 
    bbp = listeria.bbp.tree, snap = listeria.snap.aic)

Neisseria Meningitidis

# N. meningitidis: cluster under three priors, load hierBAPS and snapclust
# results, and draw both the full and the five-panel comparison.

# Import with prior = "baps" and cluster at two levels.
n.meng.data <- import_fasta_sparse_nt("./data/neisseria_meningitidis/neisseria_meningitidis_core_gene_snps.aln", 
    prior = "baps")
dim(n.meng.data$snp.matrix)
## [1] 80284   882
n.meng.multi.baps <- multi_res_baps(n.meng.data, levels = 2, n.cores = 4)

# Re-optimise the prior (type = "hc") and re-cluster.
n.meng.data <- optimise_prior(n.meng.data, type = "hc", n.cores = 4)
## [1] "Optimised hyperparameter: 0.161"
n.meng.multi.hc <- multi_res_baps(n.meng.data, levels = 2, n.cores = 4)

# Re-optimise again (type = "optimise.baps"); this prior is the one in
# effect for best_baps_partition() below.
n.meng.data <- optimise_prior(n.meng.data, type = "optimise.baps", n.cores = 4)
## [1] "Optimised hyperparameter: 0.107"
n.meng.multi.flat <- multi_res_baps(n.meng.data, levels = 2, n.cores = 4)


# hierBAPS partition, reordered to the isolate order of the fastbaps output.
n.meng.hb <- fread("./data/neisseria_meningitidis_core_gene_snps_hierbaps_partition.csv", 
    data.table = FALSE)
n.meng.hb <- n.meng.hb[match(n.meng.multi.baps$Isolates, n.meng.hb$Isolate), 
    ]

# snapclust: take the K of the run with the lowest AIC.
# NOTE(review): K is used as a column *position* in the results table --
# confirm the file layout puts the K-cluster solution in column K.
n.meng.snap <- fread("./data/neisseria_meningitidis_core_gene_snps_snapclust_results.csv", 
    data.table = FALSE)
n.meng.sum <- fread("./data/neisseria_meningitidis_core_gene_snps_snapclust_run_summary.csv", 
    data.table = FALSE)
n.meng.snap.aic <- n.meng.snap[, n.meng.sum$K[which.min(n.meng.sum$AIC)]]
names(n.meng.snap.aic) <- n.meng.snap$Isolate

n.meng.tree <- phytools::read.newick("./data/neisseria_meningitidis/neisseria_meningitidis_core_gene_snps.aln.fasttree")

# Midpoint-root the tree before computing the tree-constrained partition.
n.meng.tree <- phytools::midpoint.root(n.meng.tree)
n.meng.bbp.tree <- fastbaps::best_baps_partition(n.meng.data, n.meng.tree)
## [1] "Calculating node marginal llks..."
## [1] "Finding best partition..."
plot_custom_ggtree_snap(n.meng.tree, multi.baps = n.meng.multi.baps, multi.hc = n.meng.multi.hc, 
    multi.flat = n.meng.multi.flat, hb = n.meng.hb, bbp = n.meng.bbp.tree, snap = n.meng.snap.aic)

plot_custom_ggtree_snap_2(n.meng.tree, multi.baps = n.meng.multi.baps, multi.hc = n.meng.multi.hc, 
    multi.flat = n.meng.multi.flat, hb = n.meng.hb, bbp = n.meng.bbp.tree, snap = n.meng.snap.aic)

Staphylococcus Aureus

# S. aureus: cluster under three priors, load hierBAPS and snapclust
# results, and draw both the full and the five-panel comparison.

# Import with prior = "baps" and cluster at two levels.
s.aus.data <- import_fasta_sparse_nt("./data/staphylococcus_aureus/staphylococcus_aureus_core_gene_alignment.snps.aln", 
    prior = "baps")
dim(s.aus.data$snp.matrix)
## [1] 37091   284
s.aus.multi.baps <- multi_res_baps(s.aus.data, levels = 2, n.cores = 4)

# Re-optimise the prior (type = "hc") and re-cluster.
s.aus.data <- optimise_prior(s.aus.data, type = "hc", n.cores = 4)
## [1] "Optimised hyperparameter: 0.091"
s.aus.multi.hc <- multi_res_baps(s.aus.data, levels = 2, n.cores = 4)

# Re-optimise again (type = "optimise.baps"); this prior is the one in
# effect for best_baps_partition() below.
s.aus.data <- optimise_prior(s.aus.data, type = "optimise.baps", n.cores = 4)
## [1] "Optimised hyperparameter: 0.061"
s.aus.multi.flat <- multi_res_baps(s.aus.data, levels = 2, n.cores = 4)

# hierBAPS partition, reordered to the isolate order of the fastbaps output.
s.aus.hb <- fread("./data/staphylococcus_aureus_core_gene_alignment.snps_hierbaps_partition.csv", 
    data.table = FALSE)
s.aus.hb <- s.aus.hb[match(s.aus.multi.baps$Isolates, s.aus.hb$Isolate), ]

s.aus.tree <- phytools::read.newick("./data/staphylococcus_aureus/staphylococcus_aureus_core_gene_alignment.snps.aln.fasttree")

# snapclust: take the K of the run with the lowest AIC.
# NOTE(review): K is used as a column *position* in the results table --
# confirm the file layout puts the K-cluster solution in column K.
s.aus.snap <- fread("./data/staphylococcus_aureus_core_gene_alignment.snps_snapclust_results.csv", 
    data.table = FALSE)
s.aus.sum <- fread("./data/staphylococcus_aureus_core_gene_alignment.snps_snapclust_run_summary.csv", 
    data.table = FALSE)
s.aus.snap.aic <- s.aus.snap[, s.aus.sum$K[which.min(s.aus.sum$AIC)]]
names(s.aus.snap.aic) <- s.aus.snap$Isolate

# Midpoint-root the tree before computing the tree-constrained partition.
s.aus.tree <- phytools::midpoint.root(s.aus.tree)
s.aus.bbp.tree <- fastbaps::best_baps_partition(s.aus.data, s.aus.tree)
## [1] "Calculating node marginal llks..."
## [1] "Finding best partition..."
plot_custom_ggtree_snap(s.aus.tree, multi.baps = s.aus.multi.baps, multi.flat = s.aus.multi.flat, 
    multi.hc = s.aus.multi.hc, hb = s.aus.hb, bbp = s.aus.bbp.tree, snap = s.aus.snap.aic)

plot_custom_ggtree_snap_2(s.aus.tree, multi.baps = s.aus.multi.baps, multi.flat = s.aus.multi.flat, 
    multi.hc = s.aus.multi.hc, hb = s.aus.hb, bbp = s.aus.bbp.tree, snap = s.aus.snap.aic)

Ebola

# Ebola: cluster under three priors, load hierBAPS and snapclust results,
# and plot on the tree.

# Import with prior = "baps" and cluster at two levels.
ebola.data <- import_fasta_sparse_nt("./data/ebola/Makona_1610_cds_ig.fas", 
    prior = "baps")

ebola.multi.baps <- multi_res_baps(ebola.data, levels = 2, n.cores = 4)

# Re-optimise the prior (type = "hc") and re-cluster.
ebola.data <- optimise_prior(ebola.data, type = "hc", n.cores = 4)
## [1] "Optimised hyperparameter: 0.506"
ebola.multi.hc <- multi_res_baps(ebola.data, levels = 2, n.cores = 4)


# Unlike the bacterial sections, the third run uses the optimised
# *symmetric* prior. It is still stored as ebola.multi.flat, so the
# "opt BAPS prior" panel titles of plot_custom_ggtree_snap() do not describe
# it; plot_custom_ggtree_snap_3() carries the matching label.
ebola.data <- optimise_prior(ebola.data, type = "optimise.symmetric", n.cores = 4)
## [1] "Optimised hyperparameter: 0.007"
ebola.multi.flat <- multi_res_baps(ebola.data, levels = 2, n.cores = 4)

# hierBAPS partition, reordered to the isolate order of the fastbaps output.
ebola.hb <- fread("./data/Makona_1610_cds_ig_hierbaps_partition.csv", data.table = FALSE)
ebola.hb <- ebola.hb[match(ebola.multi.baps$Isolates, ebola.hb$Isolate), ]

# snapclust: take the K of the run with the lowest AIC.
# NOTE(review): K is used as a column *position* in the results table --
# confirm the file layout puts the K-cluster solution in column K.
ebola.snap <- fread("./data/Makona_1610_cds_ig_snapclust_results.csv", data.table = FALSE)
ebola.sum <- fread("./data/Makona_1610_cds_ig_snapclust_run_summary.csv", data.table = FALSE)
ebola.snap.aic <- ebola.snap[, ebola.sum$K[which.min(ebola.sum$AIC)]]
names(ebola.snap.aic) <- ebola.snap$Isolate

ebola.tree <- phytools::read.newick("./data/ebola/Makona_1610_cds_ig_iqtree_allnni.treefile")

# Quick look at the tree as read, before rooting.
plot(ebola.tree, show.tip.label = FALSE)

# Midpoint-root the tree before computing the tree-constrained partition.
ebola.tree <- phytools::midpoint.root(ebola.tree)
ebola.bbp.tree <- fastbaps::best_baps_partition(ebola.data, ebola.tree)
## [1] "Calculating node marginal llks..."
## [1] "Finding best partition..."
plot_custom_ggtree_snap(tree = ebola.tree, hb = ebola.hb, multi.baps = ebola.multi.baps, 
    multi.hc = ebola.multi.hc, multi.flat = ebola.multi.flat, bbp = ebola.bbp.tree, 
    snap = ebola.snap.aic)

plot_custom_ggtree_snap_3(tree = ebola.tree, hb = ebola.hb, multi.baps = ebola.multi.baps, 
    multi.hc = ebola.multi.hc, multi.flat = ebola.multi.flat, bbp = ebola.bbp.tree, 
    snap = ebola.snap.aic)

HIV

Load HIV data and remove outliers

# Import the HIV alignment, drop outlier isolates, cluster it under the
# optimised symmetric and optimised BAPS priors, and cache both results.
hiv.data <- import_fasta_sparse_nt("./data/HIV/hiv_refs_prrt_trim.fas", prior = "baps")

# Number of non-zero entries per column (isolate) of the SNP matrix.
cs <- colSums(hiv.data$snp.matrix > 0)
hist(cs, breaks = 100)

sum(cs > 200)
## [1] 86
# Keep only isolates with fewer than 200 non-zero entries.
hiv.data$snp.matrix <- hiv.data$snp.matrix[, cs < 200]

# --- optimised symmetric prior -------------------------------------------
system.time({
    hiv.data <- optimise_prior(hiv.data, type = "optimise.symmetric", n.cores = 10)
})

system.time({
    hiv.multi <- multi_res_baps(hiv.data, levels = 1, n.cores = 10, k.init = 11000)
})

# The file name now matches the prior used above. Previously the symmetric
# result was written to the "opt_baps" file and the BAPS result to the
# "opt_symmetric" file, so every downstream label was swapped.
write.csv(hiv.multi, file = "./processed_data/hiv_multi_opt_symmetric_prior_l1.csv", 
    quote = FALSE)

# --- optimised BAPS prior ------------------------------------------------
system.time({
    hiv.data <- optimise_prior(hiv.data, type = "optimise.baps", n.cores = 10)
})

system.time({
    hiv.multi <- multi_res_baps(hiv.data, levels = 1, n.cores = 10, k.init = 11000)
})

write.csv(hiv.multi, file = "./processed_data/hiv_multi_opt_baps_prior_l1.csv", 
    quote = FALSE)

# Reload the cached partitions so later steps can run without re-clustering.
hiv.multi.opt.baps <- fread("./processed_data/hiv_multi_opt_baps_prior_l1.csv", 
    data.table = FALSE)

hiv.multi.opt.symmetric <- fread("./processed_data/hiv_multi_opt_symmetric_prior_l1.csv", 
    data.table = FALSE)

Choose a pretty silly outgroup as the earliest sample.

# Read the FastTree phylogeny and drop every tip that is absent from the
# clustering results.
hiv.tree <- phytools::read.newick("./data/HIV/hiv_refs_prrt_trim.fas.fasttree")
bad.tips <- hiv.tree$tip.label[!(hiv.tree$tip.label %in% hiv.multi.opt.baps$Isolates)]
hiv.tree <- ape::drop.tip(hiv.tree, tip = bad.tips)

# Root on a tip that hangs off one of the (up to) 100 longest branches.
# `longest.edge` must hold *row indices* of the edge matrix. The previous
# code indexed `hiv.tree$edge` itself with the ordering, which returned
# parent node ids and then used those ids as row numbers.
longest.edge <- head(order(-hiv.tree$edge.length), 100)
long.nodes <- hiv.tree$edge[longest.edge, , drop = FALSE]  #this should usually include one tip
# Tips are numbered 1..Ntip in a phylo object, so the bound is inclusive.
# Column-major indexing lists all parent nodes first, then the child nodes
# in decreasing branch-length order.
tip.candidates <- long.nodes[long.nodes <= ape::Ntip(hiv.tree)]
if (length(tip.candidates) == 0) {
    stop("None of the longest edges ends in a tip; cannot choose an outgroup", 
        call. = FALSE)
}
new.outgroup <- tip.candidates[[1]]
tree.rooted <- ape::root(hiv.tree, outgroup = new.outgroup, resolve.root = TRUE)

# Tree-constrained partition under the optimised BAPS prior.
hiv.data <- optimise_prior(hiv.data, type = "optimise.baps")
best.baps.partition <- fastbaps::best_baps_partition(hiv.data, tree.rooted)

# Cache the partition and read it back as a two-column data frame.
write.table(data.frame(Isolates = names(best.baps.partition), best.baps.partition = best.baps.partition, 
    stringsAsFactors = FALSE), file = "./processed_data/best_nbaps_partition_fasttree.csv", 
    sep = ",", col.names = TRUE, row.names = FALSE, quote = FALSE)
best.baps.partition <- fread("./processed_data/best_nbaps_partition_fasttree.csv", 
    data.table = FALSE)

hiv.multi.opt.symmetric <- hiv.multi.opt.symmetric[match(tree.rooted$tip.label, 
    hiv.multi.opt.symmetric$Isolates), ]

# V1 is removed from both tables before merging.
# NOTE(review): presumably the row-name column added by write.csv -- confirm.
hiv.multi.opt.symmetric$V1 <- NULL
hiv.multi.opt.baps$V1 <- NULL

# Combine the three partitions into one table ordered like the tree tips.
hiv.multi.merged <- merge(hiv.multi.opt.symmetric, hiv.multi.opt.baps, by.x = "Isolates", 
    by.y = "Isolates")
colnames(hiv.multi.merged) <- c("Isolate", "Level 1.sym", "Level 1.baps")

hiv.multi.merged <- merge(hiv.multi.merged, best.baps.partition, by.x = "Isolate", 
    by.y = "Isolates")
hiv.multi.merged <- hiv.multi.merged[match(tree.rooted$tip.label, hiv.multi.merged$Isolate), 
    ]

# Tree on the left, cluster id per row of the merged table on the right.
par(mfrow = c(1, 2))
plot(tree.rooted, show.tip.label = FALSE)
# tiplabels(pch=1, col = hiv.multi.opt.symmetric$`Level 1`)
plot(x = hiv.multi.merged$`Level 1.baps`, y = 1:nrow(hiv.multi.merged), pch = 0.1, 
    cex = 0.2, axes = FALSE)

par(mfrow = c(1, 1))

Let's compare to a simple k-means clustering with a similar number of clusters.

First, some helper functions to find the elbow.

# Perpendicular distance from point `p` to the straight line through
# `start.point` and `end.point`. Each argument holds an x value in position
# 1 and a y value in position 2. The result is NaN when the two line points
# coincide (0/0).
distance <- function(p, start.point, end.point) {
    x0 <- p[[1]]
    y0 <- p[[2]]
    x1 <- start.point[[1]]
    y1 <- start.point[[2]]
    x2 <- end.point[[1]]
    y2 <- end.point[[2]]

    dy <- y2 - y1
    dx <- x2 - x1
    # |cross product| of (end - start) and (p - start), divided by the
    # length of the line segment.
    numerator <- abs(dy * x0 - dx * y0 + x2 * y1 - y2 * x1)
    denominator <- sqrt(dy^2 + dx^2)
    numerator/denominator
}

# For each position i of `sizes`, the distance of the point (i, sizes[i])
# from the straight line joining the first and last points. The index of the
# largest value is used by callers as the "elbow" of the curve.
#
# sizes: numeric vector (or list of numbers).
# Returns a numeric vector with one distance per element of `sizes`;
# numeric(0) for empty input.
elbow <- function(sizes) {
    n <- length(sizes)
    # Empty input used to fail on sizes[[1]]; there is simply nothing to score.
    if (n == 0) {
        return(numeric(0))
    }
    start.point <- c(1, sizes[[1]])
    end.point <- c(n, sizes[[n]])
    vapply(seq_len(n), function(i) {
        distance(c(i, sizes[[i]]), start.point, end.point)
    }, numeric(1))
}
# k-means baseline for the HIV data: cluster the first 50 principal
# components for k = 1..200, pick k at the elbow of the within-cluster sum
# of squares, and also fit a fixed k = 193.
set.seed(303)

# PCA on the 0/1 presence matrix, transposed so that isolates are rows.
pc <- irlba::prcomp_irlba(1 * t(hiv.data$snp.matrix > 0), n = 50)
# max(combined.df$best.baps.partition.tree)
all.k <- lapply(1:200, function(k) {
    k <- kmeans(pc$x, centers = k, iter.max = 100)
})

# Look for the elbow of the total within-cluster sum of squares.
# NOTE(review): this comment used to say k=1 is ignored because it lies far
# from the remaining points, but `wss[1:length(wss)]` keeps every k,
# including k=1 -- confirm which was intended.
wss <- unlist(lapply(all.k, function(x) x$tot.withinss))
elbow.k <- which.max(elbow(wss[1:length(wss)]))
plot(wss)
print(elbow.k)
abline(v = elbow.k, col = "red")

# elbow.k is reused: from here on it holds the cluster assignment of the
# elbow solution, named by isolate, rather than the chosen k.
elbow.k <- all.k[[elbow.k]]$cluster
names(elbow.k) <- colnames(hiv.data$snp.matrix)

# NOTE(review): the column name hard-codes k = 21; the k actually selected
# above is whatever print(elbow.k) reported for this run.
hiv.multi.merged$kmeans.21 <- elbow.k[match(hiv.multi.merged$Isolate, names(elbow.k))]

# temp.kmeans$cluster is looked up through the names of elbow.k, i.e. the
# SNP-matrix column names, to align it with the merged table.
temp.kmeans <- kmeans(pc$x, centers = 193, iter.max = 100)
hiv.multi.merged$kmeans.193 <- temp.kmeans$cluster[match(hiv.multi.merged$Isolate, 
    names(elbow.k))]

# Extract the type and country information from the isolate names

# Isolate names are split on "|" into exactly four fields.
combined.df <- stringr::str_split_fixed(hiv.multi.merged$Isolate, pattern = "\\|", 
    n = 4)
colnames(combined.df) <- c("Isolate", "Type", "Country", "Date")
combined.df <- cbind(combined.df, hiv.multi.merged[, 2:ncol(hiv.multi.merged)])

write.table(combined.df, file = "./processed_data/combined_clusters.csv", sep = ",", 
    quote = FALSE, col.names = TRUE, row.names = FALSE)

Plot the proportion of each cluster that is its major type or contains it, i.e. type ‘AE’ would match both ‘A’ and ‘E’.

# For every clustering method, compute per cluster the share of isolates
# whose subtype matches the cluster's most common subtype, then plot the
# distribution of those shares per method.
combined.df <- fread("./processed_data/combined_clusters.csv", data.table = FALSE)

# Fix the column order and give the method columns display names; columns
# 5 onwards are the clustering methods iterated over below.
combined.df <- combined.df[, c("Isolate", "Type", "Country", "Date", "Level 1.sym", 
    "Level 1.baps", "best.baps.partition", "kmeans.21", "kmeans.193")]
colnames(combined.df) <- c("Isolate", "Type", "Country", "Date", "Level 1 optimised symmetric prior", 
    "Level 1 optimised BAPs prior", "Best BAPs partition fasttree", "kmeans k=21", 
    "kmeans k=193")


all.method.clust.props <- lapply(5:ncol(combined.df), function(j) {
    # Subtype counts for each cluster id 1..max, sorted ascending by count.
    cluster.types <- lapply(1:max(combined.df[, j]), function(x) {
        tb <- table(combined.df$Type[combined.df[, j] == x])
        tb[order(tb)]
    })
    names(cluster.types) <- 1:max(combined.df[, j])
    # Order the clusters by size, smallest first.
    cluster.types <- cluster.types[order(unlist(lapply(cluster.types, sum)))]
    
    props <- unlist(lapply(cluster.types, function(c.table) {
        top <- names(c.table)[which.max(c.table)]
        # Keep only the capital letters (and ':') of the top subtype and use
        # each remaining character as an alternative in the match pattern.
        # NOTE(review): strsplit(..., "*") is relied on to split into single
        # characters -- confirm; strsplit(x, "") would state that directly.
        top.caps <- unlist(strsplit(gsub("[^::A-Z::]", "", top), "*"))
        matching <- grepl(paste(top.caps, collapse = "|"), names(c.table))
        return(sum(c.table[matching])/sum(c.table))
    }))
    clust.sizes <- unlist(lapply(cluster.types, sum))
    
    # data.frame() converts `cluster size` to the syntactic name
    # cluster.size, which is the name the plot below refers to.
    cluster.stats <- data.frame(method = rep(colnames(combined.df)[j], length(props)), 
        cluster = names(props), prop.match.top = props, `cluster size` = clust.sizes, 
        stringsAsFactors = FALSE)
})
all.method.clust.props <- do.call(rbind, all.method.clust.props)
# The k = 21 k-means solution is left out of the figure.
all.method.clust.props <- all.method.clust.props[!(all.method.clust.props$method %in% 
    c("kmeans k=21")), ]

gg1 <- ggplot(all.method.clust.props, aes(x = 1, y = prop.match.top)) + geom_violin(aes(fill = method)) + 
    geom_beeswarm(aes(size = cluster.size)) + facet_wrap(~method, ncol = 2) + 
    theme_bw() + ylab("Proportion comprising the dominant HIV subtype") + labs(size = "cluster size") + 
    theme(axis.title.x = element_blank(), axis.text.x = element_blank(), axis.ticks.x = element_blank()) + 
    theme(text = element_text(size = 20)) + theme(legend.position = "none")
gg1

Investigate dominant country proportions

# Same summary as for subtypes, but per cluster the share of isolates from
# the cluster's most common country (exact match on the country label).
all.method.clust.props <- lapply(5:ncol(combined.df), function(j) {
    # Country counts for each cluster id 1..max, sorted ascending by count.
    cluster.types <- lapply(1:max(combined.df[, j]), function(x) {
        tb <- table(combined.df$Country[combined.df[, j] == x])
        tb[order(tb)]
    })
    names(cluster.types) <- 1:max(combined.df[, j])
    # Order the clusters by size, smallest first.
    cluster.types <- cluster.types[order(unlist(lapply(cluster.types, sum)))]
    
    props <- unlist(lapply(cluster.types, function(c.table) {
        top <- names(c.table)[which.max(c.table)]
        matching <- top == names(c.table)
        return(sum(c.table[matching])/sum(c.table))
    }))
    clust.sizes <- unlist(lapply(cluster.types, sum))
    
    # data.frame() converts `cluster size` to the syntactic name
    # cluster.size, which is the name the plot below refers to.
    cluster.stats <- data.frame(method = rep(colnames(combined.df)[j], length(props)), 
        cluster = names(props), prop.match.top = props, `cluster size` = clust.sizes, 
        stringsAsFactors = FALSE)
})
all.method.clust.props <- do.call(rbind, all.method.clust.props)
# The k = 21 k-means solution is left out of the figure.
all.method.clust.props <- all.method.clust.props[!(all.method.clust.props$method %in% 
    c("kmeans k=21")), ]

gg2 <- ggplot(all.method.clust.props, aes(x = 1, y = prop.match.top)) + geom_violin(aes(fill = method)) + 
    geom_beeswarm(aes(size = cluster.size)) + facet_wrap(~method, ncol = 2) + 
    theme_bw() + ylab("Proportion comprising the dominant country") + labs(size = "cluster size") + 
    theme(axis.title.x = element_blank(), axis.text.x = element_blank(), axis.ticks.x = element_blank()) + 
    theme(text = element_text(size = 20))
gg2

# Show the subtype figure (gg1) and the country figure side by side.
library(patchwork)
gg1 + gg2

umap plot

# UMAP embedding of the HIV principal components, drawn once per clustering
# as coloured points and once per clustering as coloured cluster-id labels.
library(umap)
library(Matrix)

hiv.um <- umap::umap(pc$x, method = "umap-learn")

# Rebuild the full isolate name from the four split fields and reorder
# combined.df to the column order of the SNP matrix.
# NOTE(review): apply() converts the data frame to a character matrix first;
# confirm that numeric columns (e.g. Date) are not padded by that conversion,
# otherwise the rebuilt names will not match.
combined.df <- combined.df[match(colnames(hiv.data$snp.matrix), unlist(apply(combined.df[, 
    c(1, 2, 3, 4)], 1, paste, collapse = "|"))), ]

plot.df <- data.frame(Isolates = colnames(hiv.data$snp.matrix), hiv.um$layout, 
    k.means.193 = combined.df$`kmeans k=193`, k.means.21 = combined.df$`kmeans k=21`, 
    stringsAsFactors = FALSE)
plot.df <- merge(plot.df, hiv.multi.opt.baps, by.x = "Isolates", by.y = "Isolates")


# Replace the global palette with 12 colours; the integer cluster ids passed
# as `col` below are looked up in it.
palette(c("#ff7f00", "#cab2d6", "#6a3d9a", "#fdbf6f", "#ffff99", "#1f78b4", 
    "#b15928", "#33a02c", "#e31a1c", "#a6cee3", "#fb9a99", "#b2df8a"))
# Points coloured by the fastbaps level-1 cluster.
plot(X2 ~ X1, data = plot.df, col = plot.df$`Level 1`, cex = 0.01, xlab = "", 
    ylab = "", cex.axis = 1.5, font = 1.5)
mtext(side = 1, line = 3, "UMAP DIM 1", cex = 1.5)
mtext(side = 2, line = 3, "UMAP DIM 2", cex = 1.5)

# Points coloured by the k = 193 k-means cluster.
plot(X2 ~ X1, data = plot.df, col = k.means.193, cex = 0.1, xlab = "", ylab = "", 
    cex.axis = 1.5, font = 1.5)
mtext(side = 1, line = 3, "UMAP DIM 1", cex = 1.5)
mtext(side = 2, line = 3, "UMAP DIM 2", cex = 1.5)

# Points coloured by the elbow k-means cluster.
plot(X2 ~ X1, data = plot.df, col = k.means.21, cex = 0.1, xlab = "", ylab = "", 
    cex.axis = 1.5, font = 1.5)
mtext(side = 1, line = 3, "UMAP DIM 1", cex = 1.5)
mtext(side = 2, line = 3, "UMAP DIM 2", cex = 1.5)

# Label versions: the points are drawn in white only to set up the axes,
# then each isolate's cluster id is printed at its position.
plot(X2 ~ X1, data = plot.df, col = "white", cex = 0.1, xlab = "", ylab = "", 
    cex.axis = 1.5, font = 1.5)
mtext(side = 1, line = 3, "UMAP DIM 1", cex = 1.5)
mtext(side = 2, line = 3, "UMAP DIM 2", cex = 1.5)
text(x = plot.df$X1, y = plot.df$X2, labels = plot.df$`Level 1`, col = plot.df$`Level 1`, 
    cex = 0.3)

plot(X2 ~ X1, data = plot.df, col = "white", cex = 0.1, xlab = "", ylab = "", 
    cex.axis = 1.5, font = 1.5)
mtext(side = 1, line = 3, "UMAP DIM 1", cex = 1.5)
mtext(side = 2, line = 3, "UMAP DIM 2", cex = 1.5)
text(x = plot.df$X1, y = plot.df$X2, labels = plot.df$k.means.193, col = plot.df$k.means.193, 
    cex = 0.3)

plot(X2 ~ X1, data = plot.df, col = "white", cex = 0.1, xlab = "", ylab = "", 
    cex.axis = 1.5, font = 1.5)
mtext(side = 1, line = 3, "UMAP DIM 1", cex = 1.5)
mtext(side = 2, line = 3, "UMAP DIM 2", cex = 1.5)
text(x = plot.df$X1, y = plot.df$X2, labels = plot.df$k.means.21, col = plot.df$k.means.21, 
    cex = 0.3)

Investigate Maela

# Maela: cluster with the optimised BAPS prior, embed the isolates with
# PCA + UMAP, and compare the fastbaps clusters with two k-means solutions.
maela.data <- fastbaps::import_fasta_sparse_nt("./data/maela/maela3k_snps.fasta")
maela.data <- optimise_prior(maela.data, type = "optimise.baps")
## [1] "Optimised hyperparameter: 0.055"
system.time({
    maela.optimise.baps.multi <- multi_res_baps(maela.data, levels = 2, n.cores = 4)
})
##     user   system  elapsed 
## 2038.168   30.570  865.219
library(umap)
library(Matrix)

set.seed(824)

# PCA on the 0/1 presence matrix (isolates as rows), then UMAP on the PCs.
pc.maela <- irlba::prcomp_irlba(1 * t(maela.data$snp.matrix > 0), n = 50)
um.maela <- umap::umap(pc.maela$x, method = "umap-learn")
# NOTE(review): n.clusters is computed but not used; the k-means call below
# hard-codes 79 centers -- presumably the value n.clusters had in the
# original run, confirm.
n.clusters <- max(maela.optimise.baps.multi$`Level 1`)

k.means.79 <- kmeans(pc.maela$x, centers = 79, iter.max = 50, nstart = 5)$cluster

# Within-cluster sum of squares for k = 1..k.max, used to find the elbow.
k.max <- 200  # Maximal number of clusters
kmeans.results <- lapply(1:k.max, function(k) {
    kmeans(pc.maela$x, k, iter.max = 100)
})
wss <- unlist(lapply(kmeans.results, function(x) x$tot.withinss))

plot(1:k.max, wss, type = "b", pch = 19, frame = FALSE, xlab = "Number of clusters K", 
    ylab = "Total within-clusters sum of squares")

which.max(elbow(wss))
## [1] 46
# NOTE(review): the recorded elbow above is 46 but 48 centers are used here
# -- confirm which value is intended.
k.means.48 <- kmeans(pc.maela$x, centers = 48, iter.max = 50, nstart = 5)$cluster


# One row per isolate: UMAP coordinates plus the three cluster assignments.
maela.plot.df <- data.frame(Isolates = colnames(maela.data$snp.matrix), um.maela$layout, 
    k.means.79 = k.means.79, k.means.48 = k.means.48, stringsAsFactors = FALSE)
maela.plot.df <- merge(maela.plot.df, maela.optimise.baps.multi, by.x = "Isolates", 
    by.y = "Isolates")


# Replace the global palette with 11 colours; the integer cluster ids passed
# as `col` below are looked up in it.
palette(c("#a6cee3", "#1f78b4", "#b2df8a", "#33a02c", "#fb9a99", "#e31a1c", 
    "#fdbf6f", "#ff7f00", "#cab2d6", "#6a3d9a", "#b15928"))
plot(X2 ~ X1, data = maela.plot.df, col = maela.plot.df$`Level 1`, cex = 1, 
    xlab = "", ylab = "", cex.axis = 1.5, font = 1.5)
mtext(side = 1, line = 3, "UMAP DIM 1", cex = 1.5)
mtext(side = 2, line = 3, "UMAP DIM 2", cex = 1.5)

# The same 11 colours as RGB triples, rebuilt with alpha = 100 (of 255) so
# that overlapping labels stay readable.
cols <- list(c(166, 206, 227), c(31, 120, 180), c(178, 223, 138), c(51, 160, 
    44), c(251, 154, 153), c(227, 26, 28), c(253, 191, 111), c(255, 127, 0), 
    c(202, 178, 214), c(106, 61, 154), c(177, 89, 40))
palette(unlist(lapply(cols, function(col) rgb(col[[1]], col[[2]], col[[3]], 
    alpha = 100, maxColorValue = 255))))

# Label versions: the points are drawn in white only to set up the axes,
# then each isolate's cluster id is printed at its position.
plot(X2 ~ X1, data = maela.plot.df, col = "white", cex = 0.1, xlab = "", ylab = "", 
    cex.axis = 1.5, font = 1.5)
mtext(side = 1, line = 3, "UMAP DIM 1", cex = 1.5)
mtext(side = 2, line = 3, "UMAP DIM 2", cex = 1.5)
text(x = maela.plot.df$X1, y = maela.plot.df$X2, labels = maela.plot.df$`Level 1`, 
    col = maela.plot.df$`Level 1`, cex = 0.6)

plot(X2 ~ X1, data = maela.plot.df, col = "white", cex = 0.1, xlab = "", ylab = "", 
    cex.axis = 1.5, font = 1.5)
mtext(side = 1, line = 3, "UMAP DIM 1", cex = 1.5)
mtext(side = 2, line = 3, "UMAP DIM 2", cex = 1.5)
text(x = maela.plot.df$X1, y = maela.plot.df$X2, labels = maela.plot.df$k.means.48, 
    col = maela.plot.df$k.means.48, cex = 0.6)

plot(X2 ~ X1, data = maela.plot.df, col = "white", cex = 0.1, xlab = "", ylab = "", 
    cex.axis = 1.5, font = 1.5)
mtext(side = 1, line = 3, "UMAP DIM 1", cex = 1.5)
mtext(side = 2, line = 3, "UMAP DIM 2", cex = 1.5)
text(x = maela.plot.df$X1, y = maela.plot.df$X2, labels = maela.plot.df$k.means.79, 
    col = maela.plot.df$k.means.79, cex = 0.6)